{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "18d2fbd4",
   "metadata": {},
   "source": [
    "# Learned drift detectors on Adult Census\n",
    "\n",
    "Under the hood, drift detectors leverage a function (also known as a test-statistic) that is expected to take a large value if drift has occurred and a low value if not. The power of the detector is partly determined by how well the function satisfies this property. However, specifying such a function in advance can be very difficult. \n",
    "\n",
    "## Detecting drift with a learned classifier\n",
    "\n",
    "The classifier-based drift detector simply tries to correctly distinguish instances from the reference data vs. the test set. The classifier is trained to output the probability that a given instance belongs to the test set. If the probabilities it assigns to unseen tests instances are significantly higher (as determined by a Kolmogorov-Smirnov test) than those it assigns to unseen reference instances then the test set must differ from the reference set and drift is flagged. To leverage all the available reference and test data, stratified cross-validation can be applied and the out-of-fold predictions are used for the significance test. Note that a new classifier is trained for each test set or even each fold within the test set.\n",
    "\n",
    "### Backend\n",
    "\n",
    "The method works with both the **PyTorch**, **TensorFlow**, and **Sklearn** frameworks. We will focus exclusively on the **Sklearn** backend in this notebook.\n",
    "\n",
    "### Dataset\n",
    "**Adult** dataset consists of 32,561 distributed over 2 classes based on whether the annual income is >50K. We evaluate drift on particular subsets of the data which are constructed based on the education level. As we will further discuss, our reference dataset will consist of people having a **low education** level, while our test dataset will consist of people having a **high education** level.\n",
    "\n",
    "**Note**: we need to install ``alibi`` to fetch the ``adult`` dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e863c394",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install alibi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fd3559a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from typing import List, Tuple, Dict, Callable\n",
    "\n",
    "from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier\n",
    "from sklearn.svm import LinearSVC\n",
    "\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from alibi.datasets import fetch_adult\n",
    "from alibi_detect.cd import ClassifierDrift"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0aac0471",
   "metadata": {},
   "source": [
    "### Load Adult Census Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8496e24e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fetch adult dataset\n",
    "adult = fetch_adult()\n",
    "\n",
    "# separate columns in numerical and categorical.\n",
    "categorical_names = [adult.feature_names[i] for i in adult.category_map.keys()]\n",
    "categorical_ids = list(adult.category_map.keys())\n",
    "\n",
    "numerical_names = [name for i, name in enumerate(adult.feature_names) if i not in adult.category_map.keys()]\n",
    "numerical_ids = [i for i in range(len(adult.feature_names)) if i not in adult.category_map.keys()]\n",
    "\n",
    "X = adult.data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c30925b",
   "metadata": {},
   "source": [
    "We split the dataset in two based on the education level. We define a `low_education` level consisting of: `'Dropout'`, `'High School grad'`, `'Bachelors'`, and a `high_education` level consisting of: `'Bachelors'`, `'Masters'`, `'Doctorate'`. Intentionally we included an overlap between the two distributions consisting of people that have a `Bachelors` degree. Our goal is to detect that the two distributions are different."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b0ffdadb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Associates', 'Bachelors', 'Doctorate', 'Dropout', 'High School grad', 'Masters', 'Prof-School']\n"
     ]
    }
   ],
   "source": [
    "education_col = adult.feature_names.index('Education')\n",
    "education = adult.category_map[education_col]\n",
    "print(education)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "979a0946",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Low education: ['Dropout', 'High School grad', 'Bachelors']\n",
      "High education: ['Bachelors', 'Masters', 'Doctorate']\n"
     ]
    }
   ],
   "source": [
    "# define low education\n",
    "low_education = [\n",
    "    education.index('Dropout'),\n",
    "    education.index('High School grad'),\n",
    "    education.index('Bachelors')\n",
    "    \n",
    "]\n",
    "# define high education\n",
    "high_education = [\n",
    "    education.index('Bachelors'),\n",
    "    education.index('Masters'),\n",
    "    education.index('Doctorate')\n",
    "]\n",
    "print(\"Low education:\", [education[i] for i in low_education])\n",
    "print(\"High education:\", [education[i] for i in high_education])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2244a5ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# select instances for low and high education\n",
    "low_education_mask = pd.Series(X[:, education_col]).isin(low_education).to_numpy()\n",
    "high_education_mask = pd.Series(X[:, education_col]).isin(high_education).to_numpy()\n",
    "X_low, X_high = X[low_education_mask], X[high_education_mask]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e58c66b",
   "metadata": {},
   "source": [
    "We sample our reference dataset from the `low_education` level. In addition, we sample two other datasets:\n",
    "\n",
    " * `x_h0` - sampled from the `low_education` level to support the null hypothesis (i.e., the two distributions are identical);\n",
    "\n",
    "* `x_h1` - sampled from the `high_education` level to support the alternative hypothesis (i.e., the two distributions are different);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "533c269f",
   "metadata": {},
   "outputs": [],
   "source": [
    "size = 1000\n",
    "np.random.seed(0)\n",
    "\n",
    "# define reference and H0 dataset\n",
    "idx_low = np.random.choice(np.arange(X_low.shape[0]), size=2*size, replace=False)\n",
    "x_ref, x_h0 = train_test_split(X_low[idx_low], test_size=0.5, random_state=5, shuffle=True)\n",
    "\n",
    "# define reference and H1 dataset\n",
    "idx_high = np.random.choice(np.arange(X_high.shape[0]), size=size, replace=False)\n",
    "x_h1 = X_high[idx_high]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18375488",
   "metadata": {},
   "source": [
    "### Define dataset pre-processor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9da1141b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define numerical standard scaler.\n",
    "num_transf = StandardScaler()\n",
    "\n",
    "# define categorical one-hot encoder.\n",
    "cat_transf = OneHotEncoder(\n",
    "    categories=[range(len(x)) for x in adult.category_map.values()],\n",
    "    handle_unknown=\"ignore\"\n",
    ")\n",
    "\n",
    "# Define column transformer\n",
    "preprocessor = ColumnTransformer(\n",
    "    transformers=[\n",
    "        (\"cat\", cat_transf, categorical_ids),\n",
    "        (\"num\", num_transf, numerical_ids),\n",
    "    ],\n",
    "    sparse_threshold=0\n",
    ")\n",
    "\n",
    "# fit preprocessor.\n",
    "preprocessor = preprocessor.fit(np.concatenate([x_ref, x_h0, x_h1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd28745a",
   "metadata": {},
   "source": [
    "### Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6ec08c0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = ['No!', 'Yes!']\n",
    "\n",
    "def print_preds(preds: dict, preds_name: str) -> None:\n",
    "    print(preds_name)\n",
    "    print('Drift? {}'.format(labels[preds['data']['is_drift']]))\n",
    "    print(f'p-value: {preds[\"data\"][\"p_val\"]:.3f}')\n",
    "    print('')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b08bb33",
   "metadata": {},
   "source": [
    "### Drift detection "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62d35a5d",
   "metadata": {},
   "source": [
    "We perform a **binomial** test using a `RandomForestClassifier`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "089cf9a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Both `n_folds` and `train_size` specified. By default `n_folds` is used.\n",
      "`retrain_from_scratch=True` sets automatically the parameter `warm_start=False`.\n",
      "`use_oob=False` sets automatically the classifier parameters `oob_score=False`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H0\n",
      "Drift? No!\n",
      "p-value: 0.681\n",
      "\n",
      "H1\n",
      "Drift? Yes!\n",
      "p-value: 0.000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# define classifier\n",
    "model = RandomForestClassifier()\n",
    "\n",
    "# define drift detector with binarize prediction\n",
    "detector = ClassifierDrift(\n",
    "    x_ref=x_ref,\n",
    "    model=model,\n",
    "    backend='sklearn',\n",
    "    preprocess_fn=preprocessor.transform,\n",
    "    binarize_preds=True,\n",
    "    n_folds=2,\n",
    ")\n",
    "\n",
    "# print results\n",
    "print_preds(detector.predict(x=x_h0), \"H0\")\n",
    "print_preds(detector.predict(x=x_h1), \"H1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5182f338",
   "metadata": {},
   "source": [
    "As expected, when testing against `x_h0`, we fail to reject $H_0$, while for the second case there is enough evidence to reject $H_0$ and flag that the data has drifted.\n",
    "\n",
    "For the classifiers that do not support `predict_proba` but offer support for `decision_function`, we can perform a **K-S** test on the scores by setting `preds_type='scores'`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1858161c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Both `n_folds` and `train_size` specified. By default `n_folds` is used.\n",
      "`retrain_from_scratch=True` sets automatically the parameter `warm_start=False`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H0\n",
      "Drift? No!\n",
      "p-value: 0.294\n",
      "\n",
      "H1\n",
      "Drift? Yes!\n",
      "p-value: 0.000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# define model\n",
    "model = GradientBoostingClassifier()\n",
    "\n",
    "\n",
    "# define drift detector\n",
    "detector = ClassifierDrift(\n",
    "    x_ref=x_ref,\n",
    "    model=model,\n",
    "    backend='sklearn',\n",
    "    preprocess_fn=preprocessor.transform,\n",
    "    preds_type='scores',\n",
    "    binarize_preds=False,\n",
    "    n_folds=2,\n",
    ")\n",
    "\n",
    "# print results\n",
    "print_preds(detector.predict(x=x_h0), \"H0\")\n",
    "print_preds(detector.predict(x=x_h1), \"H1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e7ee5c8",
   "metadata": {},
   "source": [
    "Some models can return a poor estimate of the class label probability or some might not even support probability predictions. We can add calibration on top of each classifier to obtain better probability estimates and perform a **K-S** test. For demonstrative purposes, we will calibrate a ``LinearSVC`` which does not support ``predict_proba``, but any other classifier would work."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "08d480e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Both `n_folds` and `train_size` specified. By default `n_folds` is used.\n",
      "Using calibration to obtain the prediction probabilities.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H0\n",
      "Drift? No!\n",
      "p-value: 0.457\n",
      "\n",
      "H1\n",
      "Drift? Yes!\n",
      "p-value: 0.000\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# define model - does not support predict_proba\n",
    "model = LinearSVC(max_iter=10000)\n",
    "\n",
    "# define drift detector\n",
    "detector = ClassifierDrift(\n",
    "    x_ref=x_ref,\n",
    "    model=model,\n",
    "    backend='sklearn',\n",
    "    preprocess_fn=preprocessor.transform,\n",
    "    binarize_preds=False,\n",
    "    n_folds=2,\n",
    "    use_calibration=True,\n",
    "    calibration_kwargs={'method': 'isotonic'}\n",
    ")\n",
    "\n",
    "# print results\n",
    "print_preds(detector.predict(x=x_h0), \"H0\")\n",
    "print_preds(detector.predict(x=x_h1), \"H1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "450d3780",
   "metadata": {},
   "source": [
    "### Speeding things up"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "229f2cef",
   "metadata": {},
   "source": [
    "In order to use the entire dataset and obtain unbiased predictions required to perform the statistical test, the  `ClassifierDrift` detector has the option to perform a `n_folds` split. Although appealing due to its data efficiency, this method can be slow since it is required to train a number of `n_folds` classifiers. \n",
    "\n",
    "For the `RandomForestClassifier` we can avoid retraining `n_folds` classifiers by using the out-of-bag predictions. In a `RandomForestClassifier` each tree is trained on a separate dataset obtained by sampling with replacement the original training set, a method known as bagging. On average, only 63\\% unique samples from the original dataset are used to train each tree ([Bostrom](https://people.dsv.su.se/~henke/papers/bostrom08b.pdf)). Thus, for each tree, we can obtain predictions for the remaining out-of-bag samples (i.e., the rest of 37\\%). By cumulating the out-of-bag predictions across all the trees we can eventually obtain a prediction for each sample in the original dataset. Note that we used the word 'eventually' because if the number of trees is too small, covering the entire original dataset might be unlikely.\n",
    "\n",
    "For demonstrative purposes, we will compare the running time of the `ClassifierDrift` detector when using a `RandomForestClassifier` in two setups: `n_folds=5, use_oob=False` and `use_oob=True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3844cd1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_estimators = 400\n",
    "n_folds = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0037ec4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Both `n_folds` and `train_size` specified. By default `n_folds` is used.\n",
      "`retrain_from_scratch=True` sets automatically the parameter `warm_start=False`.\n",
      "`use_oob=False` sets automatically the classifier parameters `oob_score=False`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H0\n",
      "Drift? No!\n",
      "p-value: 0.670\n",
      "\n",
      "H1\n",
      "Drift? Yes!\n",
      "p-value: 0.000\n",
      "\n",
      "CPU times: user 5.13 s, sys: 4.92 ms, total: 5.14 s\n",
      "Wall time: 5.13 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# define drift detector\n",
    "detector_rf = ClassifierDrift(\n",
    "    x_ref=x_ref,\n",
    "    model=RandomForestClassifier(n_estimators=n_estimators),\n",
    "    backend='sklearn',\n",
    "    preprocess_fn=preprocessor.transform,\n",
    "    binarize_preds=False,\n",
    "    n_folds=n_folds\n",
    ")\n",
    "\n",
    "# print results\n",
    "print_preds(detector_rf.predict(x=x_h0), \"H0\")\n",
    "print_preds(detector_rf.predict(x=x_h1), \"H1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "9d215353",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`retrain_from_scratch=True` sets automatically the parameter `warm_start=False`.\n",
      "`use_oob=True` sets automatically the classifier parameters `boostrap=True` and `oob_score=True`. `train_size` and `n_folds` are ignored when `use_oob=True`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "H0\n",
      "Drift? No!\n",
      "p-value: 0.905\n",
      "\n",
      "H1\n",
      "Drift? Yes!\n",
      "p-value: 0.000\n",
      "\n",
      "CPU times: user 1.39 s, sys: 18.3 ms, total: 1.41 s\n",
      "Wall time: 1.41 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# define drift detector\n",
    "detector_rf_oob = ClassifierDrift(\n",
    "    x_ref=x_ref,\n",
    "    model=RandomForestClassifier(n_estimators=n_estimators),\n",
    "    backend='sklearn',\n",
    "    preprocess_fn=preprocessor.transform,\n",
    "    binarize_preds=False,\n",
    "    use_oob=True\n",
    ")\n",
    "\n",
    "# print results\n",
    "print_preds(detector_rf_oob.predict(x=x_h0), \"H0\")\n",
    "print_preds(detector_rf_oob.predict(x=x_h1), \"H1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5112b19",
   "metadata": {},
   "source": [
    "We can observe that in this particular setting, using the out-of-bag prediction can speed up the procedure up to almost x4."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
